{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Advanced usage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows some more advanced features of `skorch`. More examples will be added with time.\n",
    "\n",
    "<table align=\"left\"><td>\n",
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Advanced_Usage.ipynb\">\n",
    "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>  \n",
    "</td><td>\n",
    "<a target=\"_blank\" href=\"https://github.com/skorch-dev/skorch/blob/master/notebooks/Advanced_Usage.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Table of contents"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* [Setup](#Setup)\n",
    "* [Callbacks](#Callbacks)\n",
    "  * [Writing a custom callback](#Writing-a-custom-callback)\n",
    "  * [Accessing callback parameters](#Accessing-callback-parameters)\n",
    "* [Working with different data types](#Working-with-different-data-types)\n",
    "  * [Working with datasets](#Working-with-Datasets)\n",
    "  * [Working with dicts](#Working-with-dicts)\n",
    "* [Multiple return values](#Multiple-return-values-from-forward)\n",
    "  * [Implementing a simple autoencoder](#Implementing-a-simple-autoencoder)\n",
    "  * [Training the autoencoder](#Training-the-autoencoder)\n",
    "  * [Extracting the decoder and the encoder output](#Extracting-the-decoder-and-the-encoder-output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "! [ ! -z \"$COLAB_GPU\" ] && pip install torch skorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "torch.manual_seed(0);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A toy binary classification task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We generate a toy binary classification task with `sklearn`'s `make_classification`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = make_classification(1000, 20, n_informative=10, random_state=0)\n",
    "X = X.astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1000, 20), (1000,), 0.5)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape, y.shape, y.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Definition of the `pytorch` classification `module`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define a vanilla neural network with two hidden layers. The output layer should have 2 output units since there are two classes. In addition, it should have a softmax nonlinearity, because later, when calling `predict_proba`, the output from the `forward` call will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skorch import NeuralNetClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassifierModule(nn.Module):\n",
    "    def __init__(\n",
    "            self,\n",
    "            num_units=10,\n",
    "            nonlin=F.relu,\n",
    "            dropout=0.5,\n",
    "    ):\n",
    "        super(ClassifierModule, self).__init__()\n",
    "        self.num_units = num_units\n",
    "        self.nonlin = nonlin\n",
    "\n",
    "        self.dense0 = nn.Linear(20, num_units)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.dense1 = nn.Linear(num_units, 10)\n",
    "        self.output = nn.Linear(10, 2)\n",
    "\n",
    "    def forward(self, X, **kwargs):\n",
    "        X = self.nonlin(self.dense0(X))\n",
    "        X = self.dropout(X)\n",
    "        X = F.relu(self.dense1(X))\n",
    "        X = F.softmax(self.output(X), dim=-1)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Callbacks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Callbacks are a powerful and flexible way to customize the behavior of your neural network. They are all called at specific points during the model training, e.g. when training starts, or after each batch. Have a look at the `skorch.callbacks` module to see the callbacks that are already implemented."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Writing a custom callback"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Although `skorch` comes with a handful of useful callbacks, you may want to write your own. Doing so is straightforward; just keep these rules in mind:\n",
    "* Custom callbacks should inherit from `skorch.callbacks.Callback`.\n",
    "* They should implement at least one of the `on_`-methods provided by the parent class (e.g. `on_batch_begin` or `on_epoch_end`).\n",
    "* As arguments, the `on_`-methods first receive the `NeuralNet` instance and, where appropriate, the local data (e.g. the data from the current batch). Each method should also have `**kwargs` in its signature to absorb arguments it does not use.\n",
    "* *Optional*: If you have attributes that should be reset when the model is re-initialized, those attributes should be set in the `initialize` method."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is an example of a callback that remembers at which epoch the validation accuracy reached a certain value. Then, when training is finished, it calls a mock Twitter API and tweets that epoch. We proceed as follows:\n",
    "* We set the desired minimum accuracy during `__init__`.\n",
    "* We reset the critical epoch to `-1` (meaning \"not reached yet\") during `initialize`.\n",
    "* After each epoch, unless the critical accuracy was already reached earlier, we check whether the latest validation accuracy reaches it.\n",
    "* When training finishes, we send a tweet informing us whether our training was successful or not."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skorch.callbacks import Callback\n",
    "\n",
    "\n",
    "def tweet(msg):\n",
    "    print(\"~\" * 60)\n",
    "    print(\"*tweet*\", msg, \"#skorch #pytorch\")\n",
    "    print(\"~\" * 60)\n",
    "\n",
    "\n",
    "class AccuracyTweet(Callback):\n",
    "    def __init__(self, min_accuracy):\n",
    "        self.min_accuracy = min_accuracy\n",
    "\n",
    "    def initialize(self):\n",
    "        self.critical_epoch_ = -1\n",
    "\n",
    "    def on_epoch_end(self, net, **kwargs):\n",
    "        if self.critical_epoch_ > -1:\n",
    "            return\n",
    "        # look at the validation accuracy of the last epoch\n",
    "        if net.history[-1, 'valid_acc'] >= self.min_accuracy:\n",
    "            self.critical_epoch_ = len(net.history)\n",
    "\n",
    "    def on_train_end(self, net, **kwargs):\n",
    "        if self.critical_epoch_ < 0:\n",
    "            msg = \"Accuracy never reached {} :(\".format(self.min_accuracy)\n",
    "        else:\n",
    "            msg = \"Accuracy reached {} at epoch {}!!!\".format(\n",
    "                self.min_accuracy, self.critical_epoch_)\n",
    "\n",
    "        tweet(msg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we initialize a `NeuralNetClassifier` and pass our new callback in a list to the `callbacks` argument. After that, we train the model and see what happens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = NeuralNetClassifier(\n",
    "    ClassifierModule,\n",
    "    max_epochs=15,\n",
    "    lr=0.02,\n",
    "    warm_start=True,\n",
    "    callbacks=[AccuracyTweet(min_accuracy=0.7)],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss    valid_acc    valid_loss     dur\n",
      "-------  ------------  -----------  ------------  ------\n",
      "      1        \u001b[36m0.6954\u001b[0m       \u001b[32m0.6000\u001b[0m        \u001b[35m0.6844\u001b[0m  0.0176\n",
      "      2        \u001b[36m0.6802\u001b[0m       0.5950        \u001b[35m0.6817\u001b[0m  0.0150\n",
      "      3        0.6839       0.6000        \u001b[35m0.6792\u001b[0m  0.0178\n",
      "      4        \u001b[36m0.6753\u001b[0m       0.5900        \u001b[35m0.6767\u001b[0m  0.0140\n",
      "      5        0.6769       0.5950        \u001b[35m0.6742\u001b[0m  0.0172\n",
      "      6        0.6774       \u001b[32m0.6050\u001b[0m        \u001b[35m0.6720\u001b[0m  0.0166\n",
      "      7        \u001b[36m0.6693\u001b[0m       \u001b[32m0.6250\u001b[0m        \u001b[35m0.6695\u001b[0m  0.0134\n",
      "      8        0.6694       \u001b[32m0.6300\u001b[0m        \u001b[35m0.6672\u001b[0m  0.0168\n",
      "      9        0.6703       \u001b[32m0.6400\u001b[0m        \u001b[35m0.6652\u001b[0m  0.0177\n",
      "     10        \u001b[36m0.6523\u001b[0m       \u001b[32m0.6550\u001b[0m        \u001b[35m0.6623\u001b[0m  0.0151\n",
      "     11        0.6641       \u001b[32m0.6650\u001b[0m        \u001b[35m0.6603\u001b[0m  0.0134\n",
      "     12        0.6524       0.6650        \u001b[35m0.6582\u001b[0m  0.0138\n",
      "     13        \u001b[36m0.6506\u001b[0m       \u001b[32m0.6700\u001b[0m        \u001b[35m0.6553\u001b[0m  0.0126\n",
      "     14        \u001b[36m0.6489\u001b[0m       0.6650        \u001b[35m0.6527\u001b[0m  0.0132\n",
      "     15        0.6505       \u001b[32m0.6750\u001b[0m        \u001b[35m0.6502\u001b[0m  0.0133\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "*tweet* Accuracy never reached 0.7 :( #skorch #pytorch\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=ClassifierModule(\n",
       "    (dense0): Linear(in_features=20, out_features=10, bias=True)\n",
       "    (dropout): Dropout(p=0.5)\n",
       "    (dense1): Linear(in_features=10, out_features=10, bias=True)\n",
       "    (output): Linear(in_features=10, out_features=2, bias=True)\n",
       "  ),\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.fit(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Oh no, our model never reached a validation accuracy of 0.7. Let's train some more (this is possible because we set `warm_start=True`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     16        \u001b[36m0.6473\u001b[0m       0.6750        \u001b[35m0.6474\u001b[0m  0.0175\n",
      "     17        \u001b[36m0.6431\u001b[0m       \u001b[32m0.6800\u001b[0m        \u001b[35m0.6443\u001b[0m  0.0185\n",
      "     18        0.6461       \u001b[32m0.6900\u001b[0m        \u001b[35m0.6418\u001b[0m  0.0162\n",
      "     19        \u001b[36m0.6430\u001b[0m       0.6850        \u001b[35m0.6392\u001b[0m  0.0131\n",
      "     20        \u001b[36m0.6364\u001b[0m       \u001b[32m0.6950\u001b[0m        \u001b[35m0.6366\u001b[0m  0.0146\n",
      "     21        \u001b[36m0.6266\u001b[0m       \u001b[32m0.7000\u001b[0m        \u001b[35m0.6334\u001b[0m  0.0149\n",
      "     22        0.6316       0.7000        \u001b[35m0.6308\u001b[0m  0.0151\n",
      "     23        \u001b[36m0.6231\u001b[0m       0.7000        \u001b[35m0.6277\u001b[0m  0.0128\n",
      "     24        \u001b[36m0.6094\u001b[0m       0.7000        \u001b[35m0.6242\u001b[0m  0.0160\n",
      "     25        0.6250       \u001b[32m0.7050\u001b[0m        \u001b[35m0.6215\u001b[0m  0.0130\n",
      "     26        0.6180       \u001b[32m0.7150\u001b[0m        \u001b[35m0.6187\u001b[0m  0.0139\n",
      "     27        0.6186       0.7150        \u001b[35m0.6159\u001b[0m  0.0169\n",
      "     28        0.6144       0.7150        \u001b[35m0.6134\u001b[0m  0.0171\n",
      "     29        \u001b[36m0.5993\u001b[0m       0.7150        \u001b[35m0.6100\u001b[0m  0.0147\n",
      "     30        \u001b[36m0.5976\u001b[0m       0.7150        \u001b[35m0.6071\u001b[0m  0.0138\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "*tweet* Accuracy reached 0.7 at epoch 21!!! #skorch #pytorch\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=ClassifierModule(\n",
       "    (dense0): Linear(in_features=20, out_features=10, bias=True)\n",
       "    (dropout): Dropout(p=0.5)\n",
       "    (dense1): Linear(in_features=10, out_features=10, bias=True)\n",
       "    (output): Linear(in_features=10, out_features=2, bias=True)\n",
       "  ),\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.fit(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert net.history[-1, 'valid_acc'] >= 0.7"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, the validation score exceeded 0.7. Hooray!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Accessing callback parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Say you would like to use a learning rate schedule with your neural net, but you don't know what parameters are best for that schedule. Wouldn't it be nice if you could find those parameters with a grid search? With `skorch`, this is possible. Below, we show how to access the parameters of your callbacks."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To simplify the access to your callback parameters, it is best if you give your callback a name. This is achieved by passing the `callbacks` parameter a list of *name*, *callback* tuples, such as:\n",
    "\n",
    "    callbacks=[\n",
    "        ('scheduler', LearningRateScheduler),\n",
    "        ...\n",
    "    ],\n",
    "    \n",
    "This way, you can access your callbacks using the double underscore semantics (as, for instance, in an `sklearn` `Pipeline`):\n",
    "\n",
    "    callbacks__scheduler__epoch=50,\n",
    "\n",
    "So if you would like to perform a grid search on, say, the number of units in the hidden layer and the learning rate schedule, it could look something like this:\n",
    "\n",
    "    param_grid = {\n",
    "        'module__num_units': [50, 100, 150],\n",
    "        'callbacks__scheduler__epoch': [10, 50, 100],\n",
    "    }\n",
    "    \n",
    "*Note*: If you would like to refresh your knowledge on grid search, look [here](http://scikit-learn.org/stable/modules/grid_search.html#grid-search), [here](http://scikit-learn.org/stable/auto_examples/model_selection/grid_search_text_feature_extraction.html), or in the *Basic_Usage* notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below, we show how accessing the callback parameters works for our `AccuracyTweet` callback:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = NeuralNetClassifier(\n",
    "    ClassifierModule,\n",
    "    max_epochs=10,\n",
    "    lr=0.1,\n",
    "    warm_start=True,\n",
    "    callbacks=[\n",
    "        ('tweet', AccuracyTweet(min_accuracy=0.7)),\n",
    "    ],\n",
    "    callbacks__tweet__min_accuracy=0.6,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss    valid_acc    valid_loss     dur\n",
      "-------  ------------  -----------  ------------  ------\n",
      "      1        \u001b[36m0.6932\u001b[0m       \u001b[32m0.5950\u001b[0m        \u001b[35m0.6749\u001b[0m  0.0161\n",
      "      2        \u001b[36m0.6686\u001b[0m       \u001b[32m0.6550\u001b[0m        \u001b[35m0.6613\u001b[0m  0.0160\n",
      "      3        \u001b[36m0.6641\u001b[0m       0.6450        \u001b[35m0.6487\u001b[0m  0.0168\n",
      "      4        \u001b[36m0.6438\u001b[0m       \u001b[32m0.6600\u001b[0m        \u001b[35m0.6354\u001b[0m  0.0171\n",
      "      5        \u001b[36m0.6293\u001b[0m       \u001b[32m0.7000\u001b[0m        \u001b[35m0.6190\u001b[0m  0.0171\n",
      "      6        \u001b[36m0.6091\u001b[0m       \u001b[32m0.7300\u001b[0m        \u001b[35m0.6040\u001b[0m  0.0147\n",
      "      7        \u001b[36m0.5872\u001b[0m       \u001b[32m0.7500\u001b[0m        \u001b[35m0.5868\u001b[0m  0.0130\n",
      "      8        \u001b[36m0.5820\u001b[0m       \u001b[32m0.7600\u001b[0m        \u001b[35m0.5736\u001b[0m  0.0138\n",
      "      9        \u001b[36m0.5778\u001b[0m       \u001b[32m0.7850\u001b[0m        \u001b[35m0.5595\u001b[0m  0.0129\n",
      "     10        \u001b[36m0.5626\u001b[0m       0.7750        \u001b[35m0.5484\u001b[0m  0.0131\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "*tweet* Accuracy reached 0.6 at epoch 2!!! #skorch #pytorch\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=ClassifierModule(\n",
       "    (dense0): Linear(in_features=20, out_features=10, bias=True)\n",
       "    (dropout): Dropout(p=0.5)\n",
       "    (dense1): Linear(in_features=10, out_features=10, bias=True)\n",
       "    (output): Linear(in_features=10, out_features=2, bias=True)\n",
       "  ),\n",
       ")"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.fit(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, passing `callbacks__tweet__min_accuracy=0.6` overrode the value of 0.7 that we gave to the `AccuracyTweet` constructor. The same can be achieved by calling the `set_params` method with the corresponding arguments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=ClassifierModule(\n",
       "    (dense0): Linear(in_features=20, out_features=10, bias=True)\n",
       "    (dropout): Dropout(p=0.5)\n",
       "    (dense1): Linear(in_features=10, out_features=10, bias=True)\n",
       "    (output): Linear(in_features=10, out_features=2, bias=True)\n",
       "  ),\n",
       ")"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.set_params(callbacks__tweet__min_accuracy=0.75)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss    valid_acc    valid_loss     dur\n",
      "-------  ------------  -----------  ------------  ------\n",
      "     11        \u001b[36m0.5513\u001b[0m       0.7750        \u001b[35m0.5405\u001b[0m  0.0136\n",
      "     12        0.5612       0.7800        \u001b[35m0.5361\u001b[0m  0.0133\n",
      "     13        \u001b[36m0.5473\u001b[0m       \u001b[32m0.7950\u001b[0m        \u001b[35m0.5303\u001b[0m  0.0159\n",
      "     14        \u001b[36m0.5304\u001b[0m       0.7900        \u001b[35m0.5241\u001b[0m  0.0162\n",
      "     15        \u001b[36m0.5088\u001b[0m       0.7850        \u001b[35m0.5198\u001b[0m  0.0170\n",
      "     16        0.5373       0.7800        \u001b[35m0.5168\u001b[0m  0.0170\n",
      "     17        0.5377       0.7750        0.5179  0.0169\n",
      "     18        0.5257       0.7700        0.5177  0.0171\n",
      "     19        0.5150       0.7700        \u001b[35m0.5132\u001b[0m  0.0169\n",
      "     20        0.5136       0.7450        \u001b[35m0.5116\u001b[0m  0.0167\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "*tweet* Accuracy reached 0.75 at epoch 11!!! #skorch #pytorch\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=ClassifierModule(\n",
       "    (dense0): Linear(in_features=20, out_features=10, bias=True)\n",
       "    (dropout): Dropout(p=0.5)\n",
       "    (dense1): Linear(in_features=10, out_features=10, bias=True)\n",
       "    (output): Linear(in_features=10, out_features=2, bias=True)\n",
       "  ),\n",
       ")"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.fit(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Working with different data types"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Working with `Dataset`s"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We encourage you not to pass `Dataset`s to `net.fit` but to let skorch handle `Dataset`s internally. Nonetheless, there are situations where passing `Dataset`s to `net.fit` is hard to avoid (e.g. if you want to load the data lazily during training). This is supported by skorch but may have some unwanted side effects relating to sklearn. For instance, a `Dataset` cannot be split into train and validation sets in a stratified fashion without explicit knowledge of the classification targets."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below we show what happens when you try to fit with a `Dataset` and the stratified split fails:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, X, y):\n",
    "        self.X = X\n",
    "        self.y = y\n",
    "        \n",
    "        assert len(X) == len(y)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return self.X[i], self.y[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = make_classification(1000, 20, n_informative=10, random_state=0)\n",
    "X = X.astype(np.float32)\n",
    "dataset = MyDataset(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = NeuralNetClassifier(ClassifierModule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: Stratified CV requires explicitely passing a suitable y.\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    net.fit(dataset, y=None)\n",
    "except ValueError as e:\n",
    "    print(\"Error:\", e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.train_split.stratified"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, the stratified split fails since `y` is not known. There are two solutions to this:\n",
    "\n",
    "* turn off stratified splitting (`net.train_split.stratified=False`)\n",
    "* pass `y` explicitly (if possible), even if it is implicitly contained in the `Dataset`\n",
    "\n",
    "The second solution is shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Re-initializing module.\n",
      "Re-initializing optimizer.\n",
      "  epoch    train_loss    valid_acc    valid_loss     dur\n",
      "-------  ------------  -----------  ------------  ------\n",
      "      1        \u001b[36m0.6938\u001b[0m       \u001b[32m0.4650\u001b[0m        \u001b[35m0.6984\u001b[0m  0.0154\n",
      "      2        0.6975       0.4650        \u001b[35m0.6977\u001b[0m  0.0141\n",
      "      3        0.6938       0.4600        \u001b[35m0.6970\u001b[0m  0.0130\n",
      "      4        \u001b[36m0.6923\u001b[0m       \u001b[32m0.4700\u001b[0m        \u001b[35m0.6964\u001b[0m  0.0137\n",
      "      5        \u001b[36m0.6921\u001b[0m       \u001b[32m0.4800\u001b[0m        \u001b[35m0.6959\u001b[0m  0.0135\n",
      "      6        \u001b[36m0.6878\u001b[0m       \u001b[32m0.5000\u001b[0m        \u001b[35m0.6954\u001b[0m  0.0138\n",
      "      7        0.6901       0.4950        \u001b[35m0.6948\u001b[0m  0.0130\n",
      "      8        0.6884       0.4900        \u001b[35m0.6944\u001b[0m  0.0137\n",
      "      9        0.6896       0.4900        \u001b[35m0.6940\u001b[0m  0.0130\n",
      "     10        \u001b[36m0.6870\u001b[0m       0.4850        \u001b[35m0.6936\u001b[0m  0.0130\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=ClassifierModule(\n",
       "    (dense0): Linear(in_features=20, out_features=10, bias=True)\n",
       "    (dropout): Dropout(p=0.5)\n",
       "    (dense1): Linear(in_features=10, out_features=10, bias=True)\n",
       "    (output): Linear(in_features=10, out_features=2, bias=True)\n",
       "  ),\n",
       ")"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.fit(dataset, y=y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Working with dicts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The standard case"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "skorch has built-in support for dictionaries as data containers. Here we show a somewhat contrived example of how to use dicts, but it should get the point across. First we create data and put it into a dictionary `X_dict` with two keys `X0` and `X1`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = make_classification(1000, 20, n_informative=10, random_state=0)\n",
    "X = X.astype(np.float32)\n",
    "X0, X1 = X[:, :10], X[:, 10:]\n",
    "X_dict = {'X0': X0, 'X1': X1}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When skorch passes the dict to the pytorch module, it will pass the data as keyword arguments to the forward call. That means that we should accept the two keys `X0` and `X1` in the forward method, as shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassifierWithDict(nn.Module):\n",
    "    def __init__(\n",
    "            self,\n",
    "            num_units0=50,\n",
    "            num_units1=50,\n",
    "            nonlin=F.relu,\n",
    "            dropout=0.5,\n",
    "    ):\n",
    "        super(ClassifierWithDict, self).__init__()\n",
    "        self.num_units0 = num_units0\n",
    "        self.num_units1 = num_units1\n",
    "        self.nonlin = nonlin\n",
    "\n",
    "        self.dense0 = nn.Linear(10, num_units0)\n",
    "        self.dense1 = nn.Linear(10, num_units1)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.output = nn.Linear(num_units0 + num_units1, 2)\n",
    "\n",
    "    # NOTE: We accept X0 and X1, the keys from the dict, as arguments\n",
    "    def forward(self, X0, X1, **kwargs):\n",
    "        X0 = self.nonlin(self.dense0(X0))\n",
    "        X0 = self.dropout(X0)\n",
    "\n",
    "        X1 = self.nonlin(self.dense1(X1))\n",
    "        X1 = self.dropout(X1)\n",
    "\n",
    "        X = torch.cat((X0, X1), dim=1)\n",
    "        X = F.relu(X)\n",
    "        X = F.softmax(self.output(X), dim=-1)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As long as we keep this in mind, we are good to go."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = NeuralNetClassifier(ClassifierWithDict, verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=ClassifierWithDict(\n",
       "    (dense0): Linear(in_features=10, out_features=50, bias=True)\n",
       "    (dense1): Linear(in_features=10, out_features=50, bias=True)\n",
       "    (dropout): Dropout(p=0.5)\n",
       "    (output): Linear(in_features=100, out_features=2, bias=True)\n",
       "  ),\n",
       ")"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.fit(X_dict, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Working with sklearn `FunctionTransformer` and `GridSearchCV`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import FunctionTransformer\n",
    "from sklearn.model_selection import GridSearchCV"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "sklearn assumes that incoming data are numpy/sparse arrays or something similar, which clashes with the use of dictionaries. Sometimes there is currently no way around this (for instance, when using skorch with `BaggingClassifier`). Other times, workarounds exist.\n",
    "\n",
    "When we have a preprocessing pipeline that involves `FunctionTransformer`, we have to pass the parameter `validate=False` so that sklearn allows the dictionary to pass through:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = Pipeline([\n",
    "    ('do-nothing', FunctionTransformer(validate=False)),\n",
    "    ('net', net),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pipeline(memory=None,\n",
       "     steps=[('do-nothing', FunctionTransformer(accept_sparse=False, check_inverse=True, func=None,\n",
       "          inv_kw_args=None, inverse_func=None, kw_args=None,\n",
       "          pass_y='deprecated', validate=False)), ('net', <class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=ClassifierWithDi... (dropout): Dropout(p=0.5)\n",
       "    (output): Linear(in_features=100, out_features=2, bias=True)\n",
       "  ),\n",
       "))])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipe.fit(X_dict, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When trying a grid or randomized search, it is not that easy to pass a dict. If we try, we will get an error:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "param_grid = {\n",
    "    'net__module__num_units0': [10, 25, 50], \n",
    "    'net__module__num_units1': [10, 25, 50],\n",
    "    'net__lr': [0.01, 0.1],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_search = GridSearchCV(pipe, param_grid, scoring='accuracy', verbose=1, cv=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found input variables with inconsistent numbers of samples: [2, 1000]\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    grid_search.fit(X_dict, y)\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The error above occurs because sklearn gets the length of the input data, which is 2 for the dict, and believes that is inconsistent with the length of the target (1000). \n",
    "\n",
    "To get around that, skorch provides a helper class called `SliceDict`. It allows us to wrap our dictionaries so that they also behave like a numpy array:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skorch.helper import SliceDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_slice_dict = SliceDict(X0=X0, X1=X1)  # X_slice_dict = SliceDict(**X_dict) would also work"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The SliceDict shows the correct length, shape, and is sliceable across values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length of dict: 2, length of SliceDict: 1000\n",
      "Shape of SliceDict: (1000,)\n"
     ]
    }
   ],
   "source": [
    "print(\"Length of dict: {}, length of SliceDict: {}\".format(len(X_dict), len(X_slice_dict)))\n",
    "print(\"Shape of SliceDict: {}\".format(X_slice_dict.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Slicing the SliceDict slices across values: SliceDict(**{'X0': array([[-0.9658346 , -2.1890705 ,  0.16985609,  0.8138456 , -3.375209  ,\n",
      "        -2.1430597 , -0.39585084,  2.9419577 , -2.1910605 ,  1.2443967 ],\n",
      "       [-0.454767  ,  4.339768  , -0.48572844, -4.88433   , -2.8836503 ,\n",
      "         2.6097205 , -1.952876  , -0.09192174,  0.07970932, -0.08938338]],\n",
      "      dtype=float32), 'X1': array([[ 0.04351204, -0.5150961 , -0.86073655, -1.1097169 ,  0.31839254,\n",
      "        -0.8231973 , -1.056304  , -0.89645284,  0.3759244 , -1.0849651 ],\n",
      "       [-0.60726726, -1.0674309 ,  0.48804346, -0.50230557,  0.55743027,\n",
      "         1.01592   , -1.9953582 ,  2.9030426 , -0.9739298 ,  2.1753323 ]],\n",
      "      dtype=float32)})\n"
     ]
    }
   ],
   "source": [
    "print(\"Slicing the SliceDict slices across values: {}\".format(X_slice_dict[:2]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this, we can call `GridSearch` just as expected:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 3 folds for each of 18 candidates, totalling 54 fits\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
      "[Parallel(n_jobs=1)]: Done  54 out of  54 | elapsed:   11.3s finished\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=3, error_score='raise-deprecating',\n",
       "       estimator=Pipeline(memory=None,\n",
       "     steps=[('do-nothing', FunctionTransformer(accept_sparse=False, check_inverse=True, func=None,\n",
       "          inv_kw_args=None, inverse_func=None, kw_args=None,\n",
       "          pass_y='deprecated', validate=False)), ('net', <class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=ClassifierWithDi... (dropout): Dropout(p=0.5)\n",
       "    (output): Linear(in_features=100, out_features=2, bias=True)\n",
       "  ),\n",
       "))]),\n",
       "       fit_params=None, iid='warn', n_jobs=None,\n",
       "       param_grid={'net__module__num_units0': [10, 25, 50], 'net__module__num_units1': [10, 25, 50], 'net__lr': [0.01, 0.1]},\n",
       "       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n",
       "       scoring='accuracy', verbose=1)"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_search.fit(X_slice_dict, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.754,\n",
       " {'net__lr': 0.1,\n",
       "  'net__module__num_units0': 50,\n",
       "  'net__module__num_units1': 50})"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_search.best_score_, grid_search.best_params_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multiple return values from `forward`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Often, we want our `Module.forward` method to return more than just one value. There can be several reasons for this. Maybe, the criterion requires not one but several outputs. Or perhaps we want to inspect intermediate values to learn more about our model (say inspecting attention in a sequence-to-sequence model). Fortunately, `skorch` makes it easy to achieve this. In the following, we demonstrate how to handle multiple outputs from the `Module`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To demonstrate this, we implement a very simple autoencoder. It consists of an encoder that reduces our input of 20 units to 5 units using two linear layers, and a decoder that tries to reconstruct the original input, again using two linear layers."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementing a simple autoencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skorch import NeuralNetRegressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    def __init__(self, num_units=5):\n",
    "        super().__init__()\n",
    "        self.num_units = num_units\n",
    "        \n",
    "        self.encode = nn.Sequential(\n",
    "            nn.Linear(20, 10),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(10, self.num_units),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        \n",
    "    def forward(self, X):\n",
    "        encoded = self.encode(X)\n",
    "        return encoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    def __init__(self, num_units):\n",
    "        super().__init__()\n",
    "        self.num_units = num_units\n",
    "        \n",
    "        self.decode = nn.Sequential(\n",
    "            nn.Linear(self.num_units, 10),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(10, 20),\n",
    "        )\n",
    "        \n",
    "    def forward(self, X):\n",
    "        decoded = self.decode(X)\n",
    "        return decoded"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The autoencoder module below actually returns a tuple of two values, the decoded input and the encoded input. This way, we cannot only use the decoded input to calculate the normal loss but also have access to the encoded state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AutoEncoder(nn.Module):\n",
    "    def __init__(self, num_units):\n",
    "        super().__init__()\n",
    "        self.num_units = num_units\n",
    "\n",
    "        self.encoder = Encoder(num_units=self.num_units)\n",
    "        self.decoder = Decoder(num_units=self.num_units)\n",
    "        \n",
    "    def forward(self, X):\n",
    "        encoded = self.encoder(X)\n",
    "        decoded = self.decoder(encoded)\n",
    "        return decoded, encoded  # <- return a tuple of two values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the module's `forward` method returns two values, we have to adjust our objective to do the right thing with those values. If we don't do this, the criterion wouldn't know what to do with the two values and would raise an error.\n",
    "\n",
    "One strategy would be to only use the decoded state for the loss and discard the encoded state. For this demonstration, we have a different plan: We would like the encoded state to be sparse. Therefore, we add an L1 loss of the encoded state to the reconstruction loss. This way, the net will try to reconstruct the input as accurately as possible while keeping the encoded state as sparse as possible.\n",
    "\n",
    "To implement this, the right method to override is called `get_loss`, which is where `skorch` computes and returns the loss. It gets the prediction (our tuple) and the target as input, as well as other arguments and keywords that we pass through. We create a subclass of `NeuralNetRegressor` that overrides said method and implements our idea for the loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AutoEncoderNet(NeuralNetRegressor):\n",
    "    def get_loss(self, y_pred, y_true, *args, **kwargs):\n",
    "        decoded, encoded = y_pred  # <- unpack the tuple that was returned by `forward`\n",
    "        loss_reconstruction = super().get_loss(decoded, y_true, *args, **kwargs)\n",
    "        loss_l1 = 1e-3 * torch.abs(encoded).sum()\n",
    "        return loss_reconstruction + loss_l1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Note*: Alternatively, we could have used an unaltered `NeuralNetRegressor` but implement a custom criterion that is responsible for unpacking the tuple and computing the loss."
   ]
  },
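  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The alternative mentioned in the note could look roughly like the sketch below. It is not part of `skorch`: the class name `SparseReconstructionLoss` and its `l1_weight` parameter are hypothetical names chosen for illustration. The sketch assumes that the default `get_loss` hands the module output to the criterion unchanged, so that the criterion receives the `(decoded, encoded)` tuple and has to unpack it itself. The toy tensors at the end only check the arithmetic and do not involve `skorch`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class SparseReconstructionLoss(nn.Module):\n",
    "    # hypothetical criterion: MSE reconstruction loss plus an L1 penalty on the encoded state\n",
    "    def __init__(self, l1_weight=1e-3):\n",
    "        super().__init__()\n",
    "        self.l1_weight = l1_weight\n",
    "        self.mse = nn.MSELoss()\n",
    "\n",
    "    def forward(self, y_pred, y_true):\n",
    "        decoded, encoded = y_pred  # <- unpack the tuple that was returned by the module's `forward`\n",
    "        loss_reconstruction = self.mse(decoded, y_true)\n",
    "        loss_l1 = self.l1_weight * torch.abs(encoded).sum()\n",
    "        return loss_reconstruction + loss_l1\n",
    "\n",
    "\n",
    "# check the criterion on toy tensors\n",
    "criterion = SparseReconstructionLoss(l1_weight=1e-3)\n",
    "decoded = torch.zeros(2, 3)\n",
    "encoded = torch.tensor([[1.0, -2.0], [0.0, 3.0]])\n",
    "target = torch.ones(2, 3)\n",
    "loss = criterion((decoded, encoded), target)\n",
    "\n",
    "# with skorch, the criterion would then be passed to an unaltered net, e.g.:\n",
    "# NeuralNetRegressor(AutoEncoder, module__num_units=5, criterion=SparseReconstructionLoss, criterion__l1_weight=1e-3)"
   ]
  },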
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training the autoencoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that everything is ready, we train the model as usual. We initialize our net subclass with the `AutoEncoder` module and call the `fit` method with `X` both as input and as target (since we want to reconstruct the original data):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = AutoEncoderNet(\n",
    "    AutoEncoder,\n",
    "    module__num_units=5,\n",
    "    lr=0.3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss    valid_loss     dur\n",
      "-------  ------------  ------------  ------\n",
      "      1        \u001b[36m3.8328\u001b[0m        \u001b[32m3.7855\u001b[0m  0.0233\n",
      "      2        \u001b[36m3.6989\u001b[0m        \u001b[32m3.7111\u001b[0m  0.0244\n",
      "      3        \u001b[36m3.6417\u001b[0m        \u001b[32m3.6707\u001b[0m  0.0259\n",
      "      4        \u001b[36m3.6101\u001b[0m        \u001b[32m3.6463\u001b[0m  0.0209\n",
      "      5        \u001b[36m3.5914\u001b[0m        \u001b[32m3.6310\u001b[0m  0.0226\n",
      "      6        \u001b[36m3.5799\u001b[0m        \u001b[32m3.6212\u001b[0m  0.0242\n",
      "      7        \u001b[36m3.5725\u001b[0m        \u001b[32m3.6144\u001b[0m  0.0307\n",
      "      8        \u001b[36m3.5672\u001b[0m        \u001b[32m3.6090\u001b[0m  0.0347\n",
      "      9        \u001b[36m3.5627\u001b[0m        \u001b[32m3.6036\u001b[0m  0.0239\n",
      "     10        \u001b[36m3.5570\u001b[0m        \u001b[32m3.5963\u001b[0m  0.0264\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<class '__main__.AutoEncoderNet'>[initialized](\n",
       "  module_=AutoEncoder(\n",
       "    (encoder): Encoder(\n",
       "      (encode): Sequential(\n",
       "        (0): Linear(in_features=20, out_features=10, bias=True)\n",
       "        (1): ReLU()\n",
       "        (2): Linear(in_features=10, out_features=5, bias=True)\n",
       "        (3): ReLU()\n",
       "      )\n",
       "    )\n",
       "    (decoder): Decoder(\n",
       "      (decode): Sequential(\n",
       "        (0): Linear(in_features=5, out_features=10, bias=True)\n",
       "        (1): ReLU()\n",
       "        (2): Linear(in_features=10, out_features=20, bias=True)\n",
       "      )\n",
       "    )\n",
       "  ),\n",
       ")"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.fit(X, X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Voilà, the model was trained using our custom loss function that makes use of both predicted values."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extracting the decoder and the encoder output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sometimes, we may wish to inspect all the values returned by the `foward` method of the module. There are several ways to achieve this. In theory, we can always access the module directly by using the `net.module_` attribute. However, this is unwieldy, since this completely shortcuts the prediction loop, which takes care of important steps like casting `numpy` arrays to `pytorch` tensors and batching.\n",
    "\n",
    "Also, we cannot use the `predict` method on the net. This method will only return the first output from the forward method, in this case the decoded state. The reason for this is that `predict` is part of the `sklearn` API, which requires there to be only one output. This is shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1000, 20)"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_pred = net.predict(X)\n",
    "y_pred.shape  # only the decoded state is returned"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, the net itself provides two methods to retrieve all outputs. The first one is the `net.forward` method, which retrieves *all* the predicted batches from the `Module.forward` and concatenates them. Use this to retrieve the complete decoded and encoded state:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([1000, 20]), torch.Size([1000, 5]))"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "decoded_pred, encoded_pred = net.forward(X)\n",
    "decoded_pred.shape, encoded_pred.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The other method is called `net.forward_iter`. It is similar to `net.forward` but instead of collecting all the batches, this method is lazy and only yields one batch at a time. This can be especially useful if the output doesn't fit into memory:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([128, 20]), torch.Size([128, 5]))"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for decoded_pred, encoded_pred in net.forward_iter(X):\n",
    "    # do something with each batch\n",
    "    break\n",
    "decoded_pred.shape, encoded_pred.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's make sure that our initial goal of having a sparse encoded state was met. We check how many activities are close to zero:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.8781)"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.isclose(encoded_pred, torch.zeros_like(encoded_pred)).float().mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we had hoped, the encoded state is quite sparse, with the majority of outpus being 0."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
